import os
import shutil

import numpy as np
from PIL import Image
from keras.applications.imagenet_utils import preprocess_input
from keras.models import load_model
from keras.preprocessing import image
from keras.preprocessing.image import ImageDataGenerator
from keras.utils import np_utils
from scipy.misc import imread

import dao

# Target square edge (pixels) used when resizing dataset images.
width = height = 224


def merge_images():
    """Copy products with at least 5 comment images into the triplet dataset.

    For every leaf folder under COMMENT_PATH holding more than 4 images,
    copies both the comment folder and the matching description folder
    (found under one of the DESC_BASE_PATH shards) into Dest, skipping
    entries already present in Dest. Prints the number of products copied.
    """
    COMMENT_PATH = 'D:/cdn1/comment'
    DESC_BASE_PATH = 'D:/cdn1/desc'
    Dest = 'E:/cdn/datasets/triplet'
    DESC_SUB_DIR = [path for path in os.listdir(DESC_BASE_PATH)
                    if os.path.isdir(os.path.join(DESC_BASE_PATH, path))]
    print(DESC_SUB_DIR)
    # Create the destination BEFORE listing it; the original listed first and
    # crashed when Dest did not exist yet.
    if not os.path.exists(Dest):
        os.makedirs(Dest)
    already = set(os.listdir(Dest))
    print(len(already))
    num = 0
    for dir, sub_dir, sub_files in os.walk(COMMENT_PATH):
        # Not yet at a leaf (image) directory.
        if len(sub_dir) > 0:
            continue
        # Keep only products with at least 5 comment images.
        if len(sub_files) > 4:
            dir_name = os.path.basename(dir)
            if dir_name in already:
                continue
            # Locate the matching description folder among the desc shards.
            # desc_dir stays None when no shard has it (the original could
            # raise UnboundLocalError here when DESC_SUB_DIR was empty).
            desc_dir = None
            for path in DESC_SUB_DIR:
                candidate = os.path.join(DESC_BASE_PATH, path, '2', dir_name)
                if os.path.exists(candidate):
                    desc_dir = candidate
                    break
            if desc_dir is None:
                continue
            print(len(sub_files), dir)
            # Copy the comment folder.
            shutil.copytree(dir, os.path.join(Dest, 'comment', dir_name))
            # Copy the description-page folder.
            shutil.copytree(desc_dir, os.path.join(Dest, 'desc', dir_name))
            num += 1
    print(num)


def merge_img_other(catalog=93):
    """Copy comment/description folders for products of one catalog that are
    not yet present in the triplet dataset.

    Args:
        catalog: numeric catalog id used to filter t_product rows
            (interpolated with %d, so only integers are accepted).
    """
    COMMENT_PATH = 'D:/cdn1/comment'
    DESC_BASE_PATH = 'D:/cdn1/desc'
    Dest = 'E:/cdn/datasets/triplet'
    # NOTE(review): this lists the TOP level of Dest (normally just
    # 'comment' and 'desc'), not the already-copied product ids — confirm
    # whether os.listdir(os.path.join(Dest, 'comment')) was intended.
    already = os.listdir(Dest)
    other_comments = list(set(os.listdir(COMMENT_PATH)).difference(already))
    DESC_SUB_DIR = [path for path in os.listdir(DESC_BASE_PATH)
                    if os.path.isdir(os.path.join(DESC_BASE_PATH, path))]
    print(DESC_SUB_DIR)
    sql = "select sourceProductID from t_product where catalogID=%d" % catalog
    results = dao.select_list(sql)
    # Set for O(1) membership tests in the loop below.
    catalog_goods = {result['sourceProductID'] for result in results}
    for good_id in other_comments:
        if good_id not in catalog_goods:
            continue
        comment_good_path = os.path.join(COMMENT_PATH, good_id)
        # Skip products whose comment folder is empty.
        if not os.listdir(comment_good_path):
            continue
        # Locate the matching description folder among the desc shards.
        # desc_dir stays None when absent (the original could raise
        # UnboundLocalError here when DESC_SUB_DIR was empty).
        desc_dir = None
        for path in DESC_SUB_DIR:
            candidate = os.path.join(DESC_BASE_PATH, path, '2', good_id)
            if os.path.exists(candidate):
                desc_dir = candidate
                break
        if desc_dir is None:
            continue
        # Copy the comment folder.
        shutil.copytree(comment_good_path, os.path.join(Dest, 'comment', good_id))
        # Copy the description-page folder.
        shutil.copytree(desc_dir, os.path.join(Dest, 'desc', good_id))


def resize_image(Dest):
    """Resize every image under Dest (recursively) to width x height in place.

    Images already at the target size are left untouched.

    Args:
        Dest: root directory whose tree is walked for image files.
    """
    for dir, sub_dir, sub_files in os.walk(Dest):
        for file_name in sub_files:
            print(dir, sub_dir, file_name)
            file = os.path.join(dir, file_name)
            # Context manager releases the file handle (the original left
            # it open, which blocks the overwrite on Windows).
            with Image.open(file) as img:
                if img.size == (width, height):
                    continue
                new_image = img.resize((width, height), Image.BILINEAR)
            # Overwrite in place. The original deleted the file first and
            # then saved, leaving a window where a failed save lost the image.
            new_image.save(file)


def filter_image(img_base, selected_path, move_to):
    """Delete blacklisted images from each leaf folder under img_base and
    move every folder that lost at least one image into move_to.

    Args:
        img_base: root directory holding one sub-folder of images per product.
        selected_path: directory whose file NAMES form the blacklist.
        move_to: destination directory for folders that had images removed.
    """
    # Ensure the destination exists so shutil.move nests folders inside it
    # instead of renaming the first folder to `move_to` itself.
    os.makedirs(move_to, exist_ok=True)
    # Set membership is O(1); the original scanned a list per file name.
    selected_images = set(os.listdir(selected_path))
    for dir, sub_dir, names in os.walk(img_base):
        # Skip non-leaf directories.
        if sub_dir:
            continue
        removed = False
        for name in names:
            if name in selected_images:
                os.remove(os.path.join(dir, name))
                removed = True
        if removed:
            shutil.move(dir, move_to)


def print_img_info(img_path):
    """Print catalog/name/url metadata for every product folder in img_path."""
    for good_id in os.listdir(img_path):
        # NOTE(review): good_id comes from the filesystem; building SQL by
        # string interpolation is injection-prone — prefer a parameterized
        # query if dao supports it.
        sql = "select name,catalogID,url from t_product where sourceProductID='%s'" % good_id
        result = dao.select_one(sql)
        if result is None:
            # No matching row; report and continue instead of crashing on None.
            print(good_id, 'not found in t_product')
            continue
        print(good_id, result['catalogID'], result['name'], result['url'])


def get_img_by_catalog(img_path, dest, catalogId=95):
    """Copy comment and description image folders into dest for each product
    of the given catalog that has more than five comment images and is not
    already present under dest/comment.
    """
    query = "select sourceProductID,productHTML from t_product where id >31041 and  catalogID=%d" % catalogId
    for row in dao.select_list(query):
        good_id = row['sourceProductID']
        comments_path = os.path.join(img_path, 'comment', good_id)
        desc_path = os.path.join(img_path, row['productHTML'][:15], good_id)
        # Both source folders must exist and the comment folder must hold
        # more than five images.
        sources_ok = os.path.exists(comments_path) and os.path.exists(desc_path)
        if not sources_ok or len(os.listdir(comments_path)) <= 5:
            continue
        print(comments_path)
        dest_comment = os.path.join(dest, 'comment', good_id)
        if os.path.exists(dest_comment):
            continue
        shutil.copytree(comments_path, dest_comment)
        shutil.copytree(desc_path, os.path.join(dest, 'desc', good_id))


def get_desc_root_path(good_id, desc_base_path='D:/cdn1/desc'):
    """Return the description-image directory for a product id.

    Searches each shard directory under desc_base_path for the path
    '<shard>/2/<good_id>' and returns the first one that exists.

    Args:
        good_id: product identifier (folder name) to look up.
        desc_base_path: root containing the shard directories; defaults to
            the original hard-coded location for backward compatibility.

    Returns:
        The first existing candidate path. When none exists, the LAST
        candidate tried is returned ('' if there are no shards) — callers
        historically rely on this, so the quirk is preserved.
    """
    desc_dir = ''
    for shard in os.listdir(desc_base_path):
        desc_dir = os.path.join(desc_base_path, shard, '2', good_id)
        if os.path.exists(desc_dir):
            break
    return desc_dir


def split_image(src_path, dest_path):
    """Split mixed per-product image folders under src_path back into
    'comment' and 'desc' trees under dest_path.

    Each file is classified by checking which original source tree
    (comment root or description root) contains a file with the same name.
    Products already present under dest_path/comment are skipped.
    """
    root_img_path = 'D:/cdn1'
    dest_comment = os.path.join(dest_path, 'comment')
    dest_desc = os.path.join(dest_path, 'desc')
    # Make sure the destination roots exist before listing them (the
    # original crashed on a fresh dest_path).
    os.makedirs(dest_comment, exist_ok=True)
    os.makedirs(dest_desc, exist_ok=True)
    splited = set(os.listdir(dest_comment))
    for good_id in os.listdir(src_path):
        if good_id in splited:
            continue
        img_path = os.path.join(src_path, good_id)
        dest_comment_path = os.path.join(dest_comment, good_id)
        dest_desc_path = os.path.join(dest_desc, good_id)
        if not os.path.exists(dest_comment_path):
            os.mkdir(dest_comment_path)
        if not os.path.exists(dest_desc_path):
            os.mkdir(dest_desc_path)
        # List each source tree ONCE per product; the original re-ran
        # os.listdir for every file (O(n^2)).
        comment_names = set(os.listdir(os.path.join(root_img_path, 'comment', good_id)))
        desc_names = set(os.listdir(get_desc_root_path(good_id)))
        for name in os.listdir(img_path):
            if name in comment_names:
                shutil.copy(os.path.join(img_path, name), dest_comment_path)
            elif name in desc_names:
                shutil.copy(os.path.join(img_path, name), dest_desc_path)


if __name__ == '__main__':
    # Smoke check only; the real entry points below are kept commented out
    # and enabled by hand as needed.
    for i in range(1,5):
        print(i)
    # print_img_info('E:/cdn/datasets/triplet')
    # get_img_by_catalog('D:/cdn1', 'E:/cdn/datasets/triplet-95')
    # resize_image('E:/cdn/datasets/triplet-95')
    # split_image('E:/cdn/datasets/triplet', 'E:/cdn/datasets/triplet-95')
    # merge_img_other(catalog=95)
    # img_base = 'E:/cdn/datasets/triplet'
    # print(os.listdir(img_base))
    # selected_path = 'E:\\cdn\\filtered\\useless'
    # move_to = 'E:/cdn/datasets/no_size'
    # filter_image(img_base, selected_path, move_to)

    # resize_image('E:/cdn/comment')
    # NOTE: the triple-quoted block below is disabled experimental code
    # (model inference over a comment folder, then an ImageDataGenerator
    # augmentation demo). It is a bare string literal, i.e. a no-op at
    # runtime — it is never executed.
    '''
    model = load_model('E:/unopen_model.h5')
    dir = 'E:\\cdn\\datasets\\triplet\\comment\\15284877275'
    img_names = [path for path in os.listdir(dir)]
    imgs = np.zeros((len(img_names),299,299,3))
    for i,name in enumerate(img_names):
        # img = image.load_img('E:/cdn/datasets/triplet/comment/14566770073/59ed81a1N61e71bb1.jpg',target_size=(299,299))
        img = image.load_img(os.path.join(dir,name),target_size=(299,299))
        x = image.img_to_array(img)
        x = np.expand_dims(x,axis=0)   #这三行的意思是提取 图片特征
        x = preprocess_input(x)
        imgs[i] = x
    preds = model.predict(imgs)
    # img = imread('E:/cdn/datasets/triplet/comment/14276177179/59ae5c4dNcfaec1d7.jpg')
    # img = img.resize(width, width, 3)
    # preds = model.predict(img)
    print(preds)
    
    datagen = ImageDataGenerator(
        featurewise_center=True,
        featurewise_std_normalization=True,
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True)
    x_train = np.random.rand(1,5,5,3)
    y_train = [1]
    # y_train = np_utils.to_categorical(y_train,3)
    print(x_train,y_train)
    datagen.fit(x_train)
    batches = 0
    for x_batch,y_batch in datagen.flow(x_train,y_train,batch_size=5):
        batches += 1
        print('------------------------------')
        print(x_batch,y_batch)
        # if batches>3:
        #     break
    print(batches)
    '''
